2023.08.01 pytorch_geo基礎【torch_geometric】
GNNの畳み込み層を利用して空手クラブのネットワークを学習する。
学習済みのモデルに対して
必要なライブラリを読み込み
code:python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.datasets import KarateClub
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
import networkx as nx
from matplotlib import pyplot as plt
import numpy as np
関数とクラスの定義
code:python
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
hidden_size = 5
self.conv1 = GCNConv(dataset.num_node_features, hidden_size)
self.conv2 = GCNConv(hidden_size, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F. relu(x)
x = self.conv2(x, edge_index)
return F.softmax(x, dim=1)
def check_graph(data):
print(data,'グラフ構造')
print(data.keys, 'グラフのキー')
print(data.num_nodes, 'ノード数')
print(data.num_node_features, 'ノードの特徴量数')
print(data.num_edges, 'エッジ数')
print(data.has_isolated_nodes(), '孤立ノードの有無')
print(data.has_self_loops(), '自己ループの有無')
for key in data.keys:
print('## ', key)
print(datakey)
データセットの読み込み
code:python
dataset = KarateClub()
print(len(dataset), '空手クラブデータセットが有するデータの数') # 1
print(type(dataset), 'データセットの確認')
print(dataset.num_classes, 'クラス数') # 4
print(dataset, 'datasetを表示')
data = dataset0
print(type(data), '取り出したデータセットの確認')
print(data, '取り出したデータセットの確認これで具体的なモノが見えているようだ')
data.y_onehot = F.one_hot(data.y, num_classes=4).float()
check_graph(data)
# ノード数34, エッジ数156, 孤立無し、自己ループ無し
# クラス数4
ここで抽出した変数「data」にネットワーク構造が格納されている。
code:text
Data(x=34, 34, edge_index=2, 156, y=34, train_mask=34, y_onehot=34, 4) グラフ構造
'y_onehot', 'edge_index', 'y', 'train_mask', 'x' グラフのキー
34 ノード数
34 ノードの特徴量数
156 エッジ数
このことから、生成されるNNは、
1層目
GCNConv(dataset.num_node_features, hidden_size)
num_node_features : 34 ... おそらく'x'の横の長さ
hidden_size : インスタンス生成時に与えたハイパーパラメータ、ここでは「5」
2層目
GCNConv(hidden_size, dataset.num_classes)
hidden_size : 入力層から「5」
num_classes : クラスの数「4」、data.yに格納されている値からこのデータセットは4クラス分類用である。
Networkxを用いた可視化
code:python
nxg = to_networkx(data)
pr = nx.pagerank(nxg)
pr_max = np.array(list(pr.values())).max()
draw_pos = nx.spring_layout(nxg, seed=0)
cmap = plt.get_cmap('tab10')
labels = data.y.numpy()
colors = cmap(l) for l in labels
nx.draw_networkx_nodes(
nxg,
draw_pos,
node_size = v/pr_max * 1000 for v in pr.values(),
node_color = colors,
alpha = 0.5
)
nx.draw_networkx_edges(nxg, draw_pos, arrowstyle='-', alpha=0.2)
nx.draw_networkx_labels(nxg, draw_pos, font_size=10)
plt.title('KarateClub')
plt.show()
ここからmain処理
code:python
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 1000
epochk = int(i) for i in np.linspace(1, epochs, 11)
loss_hist = []
loss_func = torch.nn.MSELoss()
# loss_func = torch.nn.NLLLoss() # 収束するけど、onehotではなくdata.yを要求するのが?
for epoch in range(epochs):
optimizer.zero_grad()
out = model(data) # 同じネットワークを何度も入力し、
# out.shape は torch.Size(34, 4)
loss = loss_func(out, data.y_onehot)
# loss = loss_func(out, data.y)
loss.backward()
optimizer.step()
loss_hist.append(loss.item())
if epochk.count(epoch) != 0:
print('# ', (epoch*100 // epochs) + 10, '%', end=' ')
print('Epoch %d, Loss %.4f ' % (epoch, loss.item()))
plt.plot(loss_hist)
パラメータを確認
code:python
params = list(model.parameters())
print(len(params))
for param in params:
print(type(param), param.shape)
結果は
code:text
<class 'torch.nn.parameter.Parameter'> torch.Size(5)
<class 'torch.nn.parameter.Parameter'> torch.Size(5, 34)
<class 'torch.nn.parameter.Parameter'> torch.Size(4)
<class 'torch.nn.parameter.Parameter'> torch.Size(4, 5)